from typing import Dict, List, Optional

import json
import importlib_resources as resources

from helm.common.hierarchical_logger import hlog
from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.adaptation.adapter_spec import AdapterSpec
from helm.benchmark.window_services.window_service import WindowService
from helm.benchmark.window_services.window_service_factory import WindowServiceFactory
from helm.benchmark.window_services.tokenizer_service import TokenizerService
from helm.benchmark.metrics.metric_name import MetricName
from helm.benchmark.metrics.metric_service import MetricService
from helm.benchmark.metrics.statistic import Stat


# Python package whose bundled JSON resources are opened via `importlib_resources.files(...)`
# when `EfficiencyMetric` is constructed.
EFFICIENCY_DATA_PACKAGE: str = "helm.benchmark.efficiency_data"

# Inference runtime profiles, keyed by model deployment. Each entry provides
# `runtime_per_output_token`, `runtime_for_prompt_tokens`, and optionally `overhead`.
INFERENCE_IDEALIZED_RUNTIMES_JSON_FILENAME: str = "inference_idealized_runtimes.json"
INFERENCE_DENOISED_RUNTIMES_JSON_FILENAME: str = "inference_denoised_runtimes.json"
# Training cost data with "carbon" and "energy" sections, each keyed by model deployment
# and holding entries with a "value" field.
TRAINING_EFFICIENCY_JSON_FILENAME: str = "training_efficiency.json"


# TODO Actually make this work like a Metric. The current state is just trying to split
# it out of other Metrics to make refactoring easier.
class EfficiencyMetric:
    """Computes per-request inference and training efficiency statistics.

    The inference runtime profiles and the training cost data are read once, at construction
    time, from JSON files bundled in the `EFFICIENCY_DATA_PACKAGE` package.
    """

    def __init__(self):
        # For inference efficiency metrics:
        # `inference_idealized_runtimes.json` and `inference_denoised_runtimes.json` share the same
        # layout, keyed by model deployment. Each entry contains a `runtime_per_output_token` value
        # (the estimated runtime of generating one output token), a `runtime_for_prompt_tokens` dict
        # (a mapping from various num_prompt_tokens values to the estimated runtime of encoding a
        # prompt with that many tokens), and optionally an `overhead` value added to every estimate.
        # For example:
        # "openai/davinci": {
        #   "runtime_per_output_token": 0.080,
        #   "runtime_for_prompt_tokens": {
        #     "1": 0.016,
        #     "16": 0.018,
        #     "32": 0.020,
        #     ...
        #
        # These runtimes are generated by initializing Megatron with a model of the right size,
        # obtaining end-to-end generation times for different numbers of prompt and output tokens,
        # and then fitting a linear regression model to the runtimes: the resulting slope is the
        # runtime_per_output_token, which is the processing time for generating each output token,
        # and the y-intercept is the runtime_for_prompt_tokens, with different values for different
        # num_prompt_tokens values.
        # Profiling code and logs, and code to fit the regression model is available at
        # https://github.com/stanford-crfm/benchmarking_efficiency.
        data_package = resources.files(EFFICIENCY_DATA_PACKAGE)
        with data_package.joinpath(INFERENCE_IDEALIZED_RUNTIMES_JSON_FILENAME).open("r") as f:
            self.inference_idealized_runtimes_dict = json.load(f)
        with data_package.joinpath(INFERENCE_DENOISED_RUNTIMES_JSON_FILENAME).open("r") as f:
            self.inference_denoised_runtimes_dict = json.load(f)
        # We use estimated emitted CO2 during training (in tons of CO2) as a proxy metric
        # for training efficiency. We use reported metrics where applicable, otherwise
        # we estimate them from runtime information, type and number of hardware accelerators
        # used, region, etc.
        with data_package.joinpath(TRAINING_EFFICIENCY_JSON_FILENAME).open("r") as f:
            self.training_efficiency_dict = json.load(f)

    def _lookup_training_cost(self, cost_type: str, model_deployment: str) -> Optional[float]:
        """Return the recorded training cost of `cost_type` ("carbon" or "energy") for
        `model_deployment`, or None if the deployment has no entry in that section."""
        entry = self.training_efficiency_dict[cost_type].get(model_deployment)
        return None if entry is None else entry["value"]

    def compute_efficiency_metrics(
        self, adapter_spec: AdapterSpec, request_state: RequestState, metric_service: MetricService
    ) -> List[Stat]:
        """Compute efficiency metrics for both inference and training.

        For inference, we record both the actual runtime and an estimated idealized runtime
        for the given request with an optimized software implementation run on A100 GPU(s),
        taking into account both the number of tokens in the prompt of the request, and the
        number of generated output tokens. A denoised runtime is also reported; for offline
        batch inference it is the measured batch runtime divided by the batch size.
        For training, we report the estimated total metric tons of CO2 emitted to train the
        model, and the training energy cost when available. This is the same for each request.

        Args:
            adapter_spec: Used to select the window service (tokenizer) for counting prompt tokens.
            request_state: A completed request; `request_state.result` must not be None.
            metric_service: Used as the tokenizer service for the window service.

        Returns:
            Stats for token counts and training costs, plus the inference runtime, batch size,
            denoised runtime and idealized runtime whenever each of them is available.
        """
        assert request_state.result is not None
        result = request_state.result
        request = request_state.request

        # Compute efficiency metrics for inference.
        # Online requests report a per-request `request_time`.
        runtime: Optional[float] = None
        batch_size: Optional[int] = None
        if result.request_time is not None:
            runtime = result.request_time
            batch_size = 1
        # For models that perform offline batch inference, effective runtime is batch_request_time, but also
        # record batch_size to provide nuance. This takes precedence over `request_time`.
        is_offline_batch: bool = False
        if result.batch_request_time is not None and result.batch_size is not None:
            runtime = result.batch_request_time
            batch_size = result.batch_size
            is_offline_batch = True

        # Compute total number of prompt and output tokens.
        # Fetch the right `Tokenizer` depending on the model defined in `AdapterSpec`
        # and calculate the number of tokens in the prompt.
        tokenizer_service: TokenizerService = metric_service
        window_service: WindowService = WindowServiceFactory.get_window_service(
            adapter_spec.model_deployment, tokenizer_service
        )
        # Only the text part of a multimodal prompt is tokenized.
        multimodal_prompt = request.multimodal_prompt
        prompt: str = multimodal_prompt.text if multimodal_prompt is not None else request.prompt
        num_prompt_tokens: int = window_service.get_num_tokens(prompt)

        # Total number of tokens in the completion.
        num_completion_tokens: int = sum(len(completion.tokens) for completion in result.completions)
        # Don't include prompt in number of generated tokens (e.g., for language modeling).
        # Assume that tokens for different completions are generated sequentially (instead of batched) when
        # computing num_output_tokens (for the purpose of runtime estimation).
        num_output_tokens: int = num_completion_tokens
        if request.echo_prompt:
            # num_prompt_tokens > num_output_tokens can happen if tokenizer doesn't round trip.
            if num_prompt_tokens <= num_output_tokens:
                num_output_tokens -= num_prompt_tokens
            else:
                hlog(
                    f"WARNING: num_prompt_tokens ({num_prompt_tokens}) > num_output_tokens ({num_output_tokens}) "
                    f"for prompt: {prompt}"
                )
                num_output_tokens = 0

        idealized_runtime: Optional[float] = _compute_estimated_time_from_prompt_size_and_num_output_tokens(
            request_state, self.inference_idealized_runtimes_dict, num_prompt_tokens, num_output_tokens
        )

        denoised_runtime: Optional[float] = _compute_estimated_time_from_prompt_size_and_num_output_tokens(
            request_state, self.inference_denoised_runtimes_dict, num_prompt_tokens, num_output_tokens
        )
        # Denoised runtime for offline batch models is just the measured batch runtime, divided by
        # batch_size to get approximate per-input runtime. Only apply this when the runtime actually
        # came from `batch_request_time` (a per-request `request_time` is already per input), and
        # skip a zero batch size rather than dividing by zero.
        if is_offline_batch and runtime is not None and batch_size is not None and batch_size > 0:
            denoised_runtime = runtime / batch_size

        # Compute efficiency metrics for training.
        training_co2_cost: Optional[float] = self._lookup_training_cost("carbon", request.model_deployment)
        training_energy_cost: Optional[float] = self._lookup_training_cost("energy", request.model_deployment)

        stats = [
            Stat(MetricName("num_prompt_tokens")).add(num_prompt_tokens),
            Stat(MetricName("num_completion_tokens")).add(num_completion_tokens),
            Stat(MetricName("num_output_tokens")).add(num_output_tokens),
            Stat(MetricName("training_co2_cost")).add(training_co2_cost),
            Stat(MetricName("training_energy_cost")).add(training_energy_cost),
        ]
        if runtime is not None:
            stats.append(Stat(MetricName("inference_runtime")).add(runtime))
        if batch_size is not None:
            stats.append(Stat(MetricName("batch_size")).add(batch_size))
        if denoised_runtime is not None:
            stats.append(Stat(MetricName("inference_denoised_runtime")).add(denoised_runtime))
        if idealized_runtime is not None:
            stats.append(Stat(MetricName("inference_idealized_runtime")).add(idealized_runtime))
        return stats


def _compute_estimated_time_from_prompt_size_and_num_output_tokens(
    request_state: RequestState,
    inference_runtimes_dict: Dict[str, Dict],
    num_prompt_tokens: int,
    num_output_tokens: int,
) -> Optional[float]:
    """Estimate a request's runtime from a per-model-deployment runtime profile.

    The estimate is the prompt encoding runtime, plus `runtime_per_output_token` times
    (`num_output_tokens` - 1) when `num_output_tokens` is positive, plus the profile's
    `overhead` when one is present.

    The prompt encoding runtime is read from the smallest profiled prompt size that is at least
    `num_prompt_tokens` and scaled by `num_prompt_tokens / size`: we assume the prompt is encoded
    at the same throughput as that profiled size and that compute scales linearly with the number
    of prompt tokens. Prompts longer than every profiled size are scaled up from the largest one.

    Args:
        request_state: Request whose `request.model_deployment` selects the profile.
        inference_runtimes_dict: Profiles keyed by model deployment; each has
            `runtime_per_output_token`, `runtime_for_prompt_tokens` (prompt sizes as string keys),
            and an optional `overhead`.
        num_prompt_tokens: Number of tokens in the prompt.
        num_output_tokens: Number of generated output tokens.

    Returns:
        The estimated runtime, or None when the deployment has no profile.
    """
    deployment = request_state.request.model_deployment
    if deployment not in inference_runtimes_dict:
        # No runtime profile for this deployment, so no estimate can be made.
        return None

    profile = inference_runtimes_dict[deployment]
    per_output_token_runtime: float = profile["runtime_per_output_token"]
    # JSON object keys are strings; convert them to ints so prompt sizes compare numerically.
    prompt_runtime_by_size: Dict[int, float] = {
        int(size): size_runtime for size, size_runtime in profile["runtime_for_prompt_tokens"].items()
    }

    # Smallest profiled size that fits the whole prompt; fall back to the largest profiled size
    # when the prompt is longer than every profiled size.
    reference_size: int = next(
        (size for size in sorted(prompt_runtime_by_size) if num_prompt_tokens <= size),
        max(prompt_runtime_by_size),
    )
    prompt_runtime: float = prompt_runtime_by_size[reference_size] * (num_prompt_tokens / reference_size)
    overhead: Optional[float] = profile.get("overhead")

    # Only `num_output_tokens - 1` per-token steps are charged on top of the prompt runtime
    # (presumably the first token is covered by the prompt encoding fit); zero output tokens
    # adds nothing.
    total_runtime = prompt_runtime
    if num_output_tokens > 0:
        total_runtime += per_output_token_runtime * (num_output_tokens - 1)
    if overhead is not None:
        total_runtime += overhead
    return total_runtime
